{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# ResNet20 on CIFAR-10: Pruning\n",
        "\n",
        ".. tip:: It would take about 3 hours on Google Colab but can be run as fast as in 1 hour if you use a better GPU. You can expect slightly different accuracies than reported below depending on the system you run this notebook in. The purpose of this notebook is to demonstrate the workflow of pruning using Model Optimizer and not to achieve the best accuracy.\n",
        "\n",
        "In this tutorial, we will use Model Optimizer to make the ResNet model faster for our target deployment constraints\n",
        "using pruning without sacrificing much accuracy!\n",
        "\n",
        "By the end of this tutorial, you will:\n",
        "\n",
        "* Understand how to use Model Optimizer to prune a user-provided model to the best performing subnet architecture fitting your target deployment constraints.\n",
        "* Save and restore your pruned model for downstream tasks like fine-tuning and inference.\n",
        "\n",
        "All of this with just a few lines of code! Yes, it's that simple!\n",
        "\n",
        "Let's first install `Model Optimizer` following the [installation steps](https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/2_installation.html)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "%pip install \"nvidia-modelopt[torch]\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "BjwUMv7xkYtS"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "import os\n",
        "import random\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "from torchvision.models.resnet import BasicBlock\n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "seed = 123\n",
        "random.seed(seed)\n",
        "np.random.seed(seed)\n",
        "torch.manual_seed(seed)\n",
        "\n",
        "device = torch.device(\"cuda\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## CIFAR-10 Dataset for Image Classification\n",
        "\n",
        "For this tutorial, we will be working with the well known [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset for image classification. The dataset consists of 60k 32x32 images from 10 classes split into 50k training and 10k testing images. We will further take 5k randomly out from the training set to make it our validation set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [],
      "source": [
        "def get_cifar10_dataloaders(train_batch_size: int):\n",
        "    \"\"\"Return Train-Val-Test data loaders for the CIFAR-10 dataset.\"\"\"\n",
        "    np.random.seed(seed)\n",
        "\n",
        "    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])\n",
        "\n",
        "    # Split Train dataset into Train-Val datasets\n",
        "    train_dataset = torchvision.datasets.CIFAR10(\n",
        "        root=\"./data\",\n",
        "        train=True,\n",
        "        transform=transforms.Compose(\n",
        "            [\n",
        "                transforms.ToTensor(),\n",
        "                transforms.RandomHorizontalFlip(),\n",
        "                transforms.RandomCrop(32, 4),\n",
        "                normalize,\n",
        "            ]\n",
        "        ),\n",
        "        download=True,\n",
        "    )\n",
        "\n",
        "    n_trainval = len(train_dataset)\n",
        "    n_train = int(n_trainval * 0.9)\n",
        "    ids = np.arange(n_trainval)\n",
        "    np.random.shuffle(ids)\n",
        "    train_ids, val_ids = ids[:n_train], ids[n_train:]\n",
        "\n",
        "    train_dataset.data = train_dataset.data[train_ids]\n",
        "    train_dataset.targets = np.array(train_dataset.targets)[train_ids]\n",
        "\n",
        "    val_dataset = torchvision.datasets.CIFAR10(\n",
        "        root=\"./data\",\n",
        "        train=True,\n",
        "        transform=transforms.Compose([transforms.ToTensor(), normalize]),\n",
        "        download=True,\n",
        "    )\n",
        "    val_dataset.data = val_dataset.data[val_ids]\n",
        "    val_dataset.targets = np.array(val_dataset.targets)[val_ids]\n",
        "\n",
        "    test_dataset = torchvision.datasets.CIFAR10(\n",
        "        root=\"./data\",\n",
        "        train=False,\n",
        "        transform=val_dataset.transform,\n",
        "        download=True,\n",
        "    )\n",
        "\n",
        "    num_workers = min(8, os.cpu_count())\n",
        "    train_loader = torch.utils.data.DataLoader(\n",
        "        train_dataset, train_batch_size, num_workers=num_workers, pin_memory=True, shuffle=True\n",
        "    )\n",
        "    val_loader = torch.utils.data.DataLoader(\n",
        "        val_dataset, batch_size=1024, num_workers=num_workers, pin_memory=True\n",
        "    )\n",
        "    test_loader = torch.utils.data.DataLoader(\n",
        "        test_dataset, batch_size=1024, num_workers=num_workers, pin_memory=True\n",
        "    )\n",
        "    print(f\"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}\")\n",
        "\n",
        "    return train_loader, val_loader, test_loader"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## ResNet for CIFAR dataset\n",
        "\n",
        "We will be working with the ResNet variants for CIFAR dataset, namely ResNet-20 and ResNet-32 since these are very small models to train. You can find more details about these models in [this paper](https://arxiv.org/abs/1512.03385).\n",
        "Below is an example of a regular PyTorch model without anything new."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Setting up the model\n",
        "\n",
        "We first set up and add some helper functions for training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "def _weights_init(m):\n",
        "    if isinstance(m, (nn.Linear, nn.Conv2d)):\n",
        "        torch.nn.init.kaiming_normal_(m.weight)\n",
        "\n",
        "\n",
        "class LambdaLayer(nn.Module):\n",
        "    def __init__(self, lambd):\n",
        "        super().__init__()\n",
        "        self.lambd = lambd\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.lambd(x)\n",
        "\n",
        "\n",
        "class ResNet(nn.Module):\n",
        "    def __init__(self, num_blocks, num_classes=10):\n",
        "        super().__init__()\n",
        "        self.in_planes = 16\n",
        "        self.layers = nn.Sequential(\n",
        "            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),\n",
        "            nn.BatchNorm2d(16),\n",
        "            nn.ReLU(),\n",
        "            self._make_layer(16, num_blocks, stride=1),\n",
        "            self._make_layer(32, num_blocks, stride=2),\n",
        "            self._make_layer(64, num_blocks, stride=2),\n",
        "            nn.AdaptiveAvgPool2d((1, 1)),\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(64, num_classes),\n",
        "        )\n",
        "        self.apply(_weights_init)\n",
        "\n",
        "    def _make_layer(self, planes, num_blocks, stride):\n",
        "        strides = [stride] + [1] * (num_blocks - 1)\n",
        "        layers = []\n",
        "        for _stride in strides:\n",
        "            downsample = None\n",
        "            if _stride != 1 or self.in_planes != planes:\n",
        "                downsample = LambdaLayer(\n",
        "                    lambda x: F.pad(\n",
        "                        x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), \"constant\", 0\n",
        "                    )\n",
        "                )\n",
        "            layers.append(BasicBlock(self.in_planes, planes, _stride, downsample))\n",
        "            self.in_planes = planes\n",
        "        return nn.Sequential(*layers)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.layers(x)\n",
        "\n",
        "\n",
        "def resnet20(ckpt=None):\n",
        "    model = ResNet(num_blocks=3).to(device)\n",
        "    if ckpt is not None:\n",
        "        model.load_state_dict(torch.load(ckpt, device))\n",
        "    return model\n",
        "\n",
        "\n",
        "def resnet32(ckpt=None):\n",
        "    model = ResNet(num_blocks=5).to(device)\n",
        "    if ckpt is not None:\n",
        "        model.load_state_dict(torch.load(ckpt, device))\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):\n",
        "    def __init__(\n",
        "        self,\n",
        "        optimizer: torch.optim.Optimizer,\n",
        "        warmup_steps: int,\n",
        "        decay_steps: int,\n",
        "        warmup_lr: float = 0.0,\n",
        "        last_epoch: int = -1,\n",
        "    ) -> None:\n",
        "        self.warmup_steps = warmup_steps\n",
        "        self.warmup_lr = warmup_lr\n",
        "        self.decay_steps = decay_steps\n",
        "        super().__init__(optimizer, last_epoch)\n",
        "\n",
        "    def get_lr(self):\n",
        "        if self.last_epoch < self.warmup_steps:\n",
        "            return [\n",
        "                (base_lr - self.warmup_lr) * self.last_epoch / self.warmup_steps + self.warmup_lr\n",
        "                for base_lr in self.base_lrs\n",
        "            ]\n",
        "        else:\n",
        "            current_steps = self.last_epoch - self.warmup_steps\n",
        "            return [\n",
        "                0.5 * base_lr * (1 + math.cos(math.pi * current_steps / self.decay_steps))\n",
        "                for base_lr in self.base_lrs\n",
        "            ]\n",
        "\n",
        "\n",
        "def get_optimizer_scheduler(model, lr, weight_decay, warmup_steps, decay_steps):\n",
        "    optimizer = torch.optim.SGD(\n",
        "        filter(lambda p: p.requires_grad, model.parameters()),\n",
        "        lr,\n",
        "        momentum=0.9,\n",
        "        weight_decay=weight_decay,\n",
        "    )\n",
        "    lr_scheduler = CosineLRwithWarmup(optimizer, warmup_steps, decay_steps)\n",
        "    return optimizer, lr_scheduler\n",
        "\n",
        "\n",
        "def train_one_epoch(model, train_loader, loss_fn, optimizer, lr_scheduler):\n",
        "    \"\"\"Train the given model for 1 epoch.\"\"\"\n",
        "    model.train()\n",
        "    epoch_loss = 0.0\n",
        "    for imgs, labels in train_loader:\n",
        "        output = model(imgs.to(device))\n",
        "        loss = loss_fn(model, output, labels.to(device))\n",
        "        epoch_loss += loss.item()\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        lr_scheduler.step()\n",
        "\n",
        "    epoch_loss /= len(train_loader.dataset)\n",
        "    return epoch_loss\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate(model, test_loader):\n",
        "    \"\"\"Evaluate the model on the given test_loader and return accuracy percentage.\"\"\"\n",
        "    model.eval()\n",
        "    correct = total = 0.0\n",
        "    for imgs, labels in test_loader:\n",
        "        output = model(imgs.to(device))\n",
        "        predicted = output.argmax(dim=1).detach().cpu()\n",
        "        correct += torch.sum(labels == predicted).item()\n",
        "        total += len(labels)\n",
        "\n",
        "    accuracy = 100 * correct / total\n",
        "    return accuracy\n",
        "\n",
        "\n",
        "def loss_fn_default(model, output, labels):\n",
        "    return F.cross_entropy(output, labels)\n",
        "\n",
        "\n",
        "def train_model(\n",
        "    model,\n",
        "    train_loader,\n",
        "    val_loader,\n",
        "    optimizer,\n",
        "    lr_scheduler,\n",
        "    num_epochs,\n",
        "    loss_fn=loss_fn_default,\n",
        "    print_freq=25,\n",
        "    ckpt_path=\"temp_saved_model.pth\",\n",
        "):\n",
        "    \"\"\"Train the given model with provided parameters.\n",
        "\n",
        "    loss_fn: function that takes model, output, labels and returns loss. This allows us to obtain the loss\n",
        "        from the model as well if needed.\n",
        "    \"\"\"\n",
        "    best_val_acc, best_ep = 0.0, 0\n",
        "    print(f\"Training the model for {num_epochs} epochs...\")\n",
        "    for ep in tqdm(range(1, num_epochs + 1)):\n",
        "        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, lr_scheduler)\n",
        "\n",
        "        val_acc = evaluate(model, val_loader)\n",
        "        if val_acc >= best_val_acc:\n",
        "            best_val_acc, best_ep = val_acc, ep\n",
        "            torch.save(model.state_dict(), ckpt_path)\n",
        "\n",
        "        if ep == 1 or ep % print_freq == 0 or ep == num_epochs:\n",
        "            print(f\"Epoch {ep:3d}\\t Training loss: {train_loss:.4f}\\t Val Accuracy: {val_acc:.2f}%\")\n",
        "\n",
        "    print(\n",
        "        f\"Model Trained! Restoring to parameters that gave best Val Accuracy ({best_val_acc:.2f}% at Epoch {best_ep}).\"\n",
        "    )\n",
        "    model.load_state_dict(torch.load(ckpt_path), device)"
      ]
    },
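    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `CosineLRwithWarmup` scheduler above ramps the learning rate linearly from `warmup_lr` to the base learning rate over `warmup_steps`, then decays it along a half cosine over `decay_steps`. As a rough illustration, the same formula can be written in plain Python without PyTorch. The function name `lr_at_step` and the step counts used here are made up for this sketch and are not part of the notebook's training setup.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def lr_at_step(step, base_lr, warmup_steps, decay_steps, warmup_lr=0.0):\n",
        "    \"\"\"Hypothetical helper mirroring CosineLRwithWarmup.get_lr for one param group.\"\"\"\n",
        "    if step < warmup_steps:\n",
        "        # Linear ramp from warmup_lr to base_lr\n",
        "        return (base_lr - warmup_lr) * step / warmup_steps + warmup_lr\n",
        "    # Half-cosine decay, counted from the end of the warmup\n",
        "    current_steps = step - warmup_steps\n",
        "    return 0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_steps))\n",
        "\n",
        "\n",
        "# Illustrative values only\n",
        "for step in (0, 50, 100, 600, 1100):\n",
        "    print(step, round(lr_at_step(step, base_lr=0.4, warmup_steps=100, decay_steps=1000), 4))\n",
        "```\n",
        "\n",
        "Because the decay is counted from the end of the warmup, the learning rate only reaches zero at step `warmup_steps + decay_steps`."
      ]
    },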
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can uncomment the print statement below to see the ResNet20 model details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [],
      "source": [
        "# print(resnet20())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Training a baseline model\n",
        "\n",
        "It should take about 10-30 mins to train depending on your GPU and CPU. We use slightly \n",
        "different training hyperparameters compared to the original setup described in the paper to make the\n",
        "training faster for this tutorial. \n",
        "\n",
        "You can also reduce the `num_epochs` parameter below to make the whole notebook run faster at the cost of accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n",
            "Train: 45000, Val: 5000, Test: 10000\n"
          ]
        }
      ],
      "source": [
        "batch_size = 512\n",
        "num_epochs = 120\n",
        "learning_rate = 0.1 * batch_size / 128\n",
        "weight_decay = 1e-4\n",
        "\n",
        "train_loader, val_loader, test_loader = get_cifar10_dataloaders(batch_size)\n",
        "\n",
        "batch_per_epoch = len(train_loader)\n",
        "warmup_steps = 5 * batch_per_epoch\n",
        "decay_steps = num_epochs * batch_per_epoch"
      ]
    },
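    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the schedule lengths above follow from simple arithmetic. The sketch below assumes the 45,000-image training split printed above and the default `drop_last=False` of the PyTorch `DataLoader`, so the last partial batch also counts as a step. The variable `n_train` is introduced only for this illustration.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "batch_size = 512\n",
        "num_epochs = 120\n",
        "n_train = 45_000  # assumed: training split size after holding out 5k validation images\n",
        "\n",
        "learning_rate = 0.1 * batch_size / 128  # base LR of 0.1 at batch size 128, scaled linearly\n",
        "batch_per_epoch = math.ceil(n_train / batch_size)  # len(train_loader) with drop_last=False\n",
        "warmup_steps = 5 * batch_per_epoch  # 5 epochs of warmup\n",
        "decay_steps = num_epochs * batch_per_epoch\n",
        "\n",
        "print(learning_rate, batch_per_epoch, warmup_steps, decay_steps)\n",
        "```\n",
        "\n",
        "Under these assumptions the learning rate is 0.4 and each epoch has 88 optimizer steps."
      ]
    },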
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Training the model for 120 epochs...\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch   1\t Training loss: 0.0049\t Val Accuracy: 22.82%\n",
            "Epoch  25\t Training loss: 0.0006\t Val Accuracy: 78.84%\n",
            "Epoch  50\t Training loss: 0.0004\t Val Accuracy: 85.06%\n",
            "Epoch  75\t Training loss: 0.0002\t Val Accuracy: 88.12%\n",
            "Epoch 100\t Training loss: 0.0001\t Val Accuracy: 90.34%\n",
            "Epoch 120\t Training loss: 0.0000\t Val Accuracy: 90.80%\n",
            "Model Trained! Restoring to parameters that gave best Val Accuracy (90.92% at Epoch 119).\n",
            "Test Accuracy of ResNet20: 90.97\n"
          ]
        }
      ],
      "source": [
        "resnet20_model = resnet20()\n",
        "optimizer, lr_scheduler = get_optimizer_scheduler(\n",
        "    resnet20_model, learning_rate, weight_decay, warmup_steps, decay_steps\n",
        ")\n",
        "train_model(\n",
        "    resnet20_model,\n",
        "    train_loader,\n",
        "    val_loader,\n",
        "    optimizer,\n",
        "    lr_scheduler,\n",
        "    num_epochs,\n",
        "    ckpt_path=\"resnet20.pth\",\n",
        ")\n",
        "print(f\"Test Accuracy of ResNet20: {evaluate(resnet20_model, test_loader)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We have now established a baseline model and accuracy that we will be comparing with using Model Optimizer.\n",
        "\n",
        "So far, we have seen a regular PyTorch model trained without anything new. Now, lets optimize the model for our target constraints using Model Optimizer!\n",
        "\n",
        "## FastNAS Pruning with Model Optimizer\n",
        "\n",
        "The Model Optimizer's `modelopt.torch.prune` module provides advanced state-of-the-art pruning algorithms that enable you to search for the best subnet architecture from your provided base model.\n",
        "\n",
        "Model Optimizer can be used in one of the following complementary modes to create a search space for optimizing the model:\n",
        "\n",
        "1. `fastnas`: A pruning method recommended for Computer Vision models. Given a pretrained model, FastNAS finds the subnet which maximizes the score function while meeting the given constraints.\n",
        "1. `mcore_gpt_minitron`: A pruning method developed by NVIDIA Research for pruning GPT-style models in NVIDIA NeMo or Megatron-LM framework that are using Pipeline Parallellism. It uses the activation magnitudes to prune the mlp, attention heads, and GQA query groups.\n",
        "1. `gradnas`: A light-weight pruning method recommended for language models like Hugging Face BERT, GPT-J. It uses the gradient information to prune the model's linear layers and attention heads to meet the given constraints.\n",
        "\n",
        "In this example, we will use the `fastnas` mode to prune the ResNet20 model for CIFAR-10 dataset. Checkout the [Model Optimizer GitHub repository](https://github.com/NVIDIA/TensorRT-Model-Optimizer) for more examples.\n",
        "\n",
        "Let's first use the FastNAS mode to convert a ResNet model and reduce its FLOPs, number of parameters, and latency."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "import modelopt.torch.opt as mto\n",
        "import modelopt.torch.prune as mtp"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Prune base model and store pruned net\n",
        "\n",
        "Using [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune) you can\n",
        "\n",
        "* generate a search space for pruning from your base model;\n",
        "* prune the model;\n",
        "* obtain a valid pytorch model that can be used for fine-tuning.\n",
        "\n",
        "Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `30e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
        "\n",
        "Finally, we can store the pruned architecture and weights using [mto.save](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.save).\n",
        "\n",
        ".. note:: We are optimizing a relatively smaller model here. A finer-grained search could be more effective in such a case. This is why we are specifying custom configs. In general however, it is recommended to convert models with the default config itself."
      ]
    },
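    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two constraint formats concrete, the sketch below resolves either form to an absolute upper bound. It assumes that a percentage string is interpreted relative to the unpruned model. `resolve_upper_bound` is a hypothetical helper written for this illustration, not a Model Optimizer function, and the baseline FLOPs value is only an example.\n",
        "\n",
        "```python\n",
        "def resolve_upper_bound(value, baseline):\n",
        "    \"\"\"Hypothetical helper: turn an absolute number or a percentage string into a number.\"\"\"\n",
        "    if isinstance(value, str) and value.endswith(\"%\"):\n",
        "        # Percentage of the baseline (assumed to be the unpruned model)\n",
        "        return float(value[:-1]) / 100 * baseline\n",
        "    return float(value)\n",
        "\n",
        "\n",
        "baseline_flops = 40e6  # example value for the unpruned model, not measured here\n",
        "print(resolve_upper_bound(30e6, baseline_flops))\n",
        "print(resolve_upper_bound(\"75%\", baseline_flops))\n",
        "```\n",
        "\n",
        "With this example baseline, both calls resolve to the same bound of 30M FLOPs."
      ]
    },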
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Profiling the following subnets from the given model: ('min', 'centroid', 'max').\n",
            "--------------------------------------------------------------------------------\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
            ],
            "text/plain": []
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[3m                                                                             \u001b[0m\n",
            "\u001b[3m                              Profiling Results                              \u001b[0m\n",
            "┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
            "┃\u001b[1m \u001b[0m\u001b[1mConstraint  \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmin         \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcentroid    \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax         \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax/min ratio\u001b[0m\u001b[1m \u001b[0m┃\n",
            "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
            "│ flops        │ 24.33M       │ 27.57M       │ 40.55M       │ 1.67          │\n",
            "│ params       │ 90.94K       │ 141.63K      │ 268.35K      │ 2.95          │\n",
            "└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘\n",
            "\u001b[3m                                              \u001b[0m\n",
            "\u001b[3m            Constraints Evaluation            \u001b[0m\n",
            "┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓\n",
            "┃\u001b[1m              \u001b[0m┃\u001b[1m              \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSatisfiable \u001b[0m\u001b[1m \u001b[0m┃\n",
            "┃\u001b[1m \u001b[0m\u001b[1mConstraint  \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\n",
            "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
            "│ flops        │ 30.00M       │ True         │\n",
            "└──────────────┴──────────────┴──────────────┘\n",
            "\n",
            "\n",
            "Search Space Summary:\n",
            "----------------------------------------------------------------------------------------------------\n",
            "  layers.depth                                                                     [9]\n",
            "  layers.0.out_channels                                                            [16]\n",
            "  layers.0.in_channels                                                             [3]\n",
            "  layers.3.depth                                                                   [3]\n",
            "  layers.3.0.conv1.out_channels                                                    [16]\n",
            "  layers.3.0.conv1.in_channels                                                     [16]\n",
            "  layers.3.0.bn1.num_features                                                      [16]\n",
            "  layers.3.0.conv2.out_channels                                                    [16]\n",
            "  layers.3.0.conv2.in_channels                                                     [16]\n",
            "  layers.3.1.conv1.out_channels                                                    [16]\n",
            "  layers.3.1.conv1.in_channels                                                     [16]\n",
            "  layers.3.1.bn1.num_features                                                      [16]\n",
            "  layers.3.1.conv2.out_channels                                                    [16]\n",
            "  layers.3.1.conv2.in_channels                                                     [16]\n",
            "  layers.3.2.conv1.out_channels                                                    [16]\n",
            "  layers.3.2.conv1.in_channels                                                     [16]\n",
            "  layers.3.2.bn1.num_features                                                      [16]\n",
            "  layers.3.2.conv2.out_channels                                                    [16]\n",
            "  layers.3.2.conv2.in_channels                                                     [16]\n",
            "  layers.4.depth                                                                   [3]\n",
            "* layers.4.0.conv1.out_channels                                                    [16, 32]\n",
            "  layers.4.0.conv1.in_channels                                                     [16]\n",
            "  layers.4.0.bn1.num_features                                                      [16, 32]\n",
            "  layers.4.0.conv2.out_channels                                                    [32]\n",
            "  layers.4.0.conv2.in_channels                                                     [16, 32]\n",
            "* layers.4.1.conv1.out_channels                                                    [16, 32]\n",
            "  layers.4.1.conv1.in_channels                                                     [32]\n",
            "  layers.4.1.bn1.num_features                                                      [16, 32]\n",
            "  layers.4.1.conv2.out_channels                                                    [32]\n",
            "  layers.4.1.conv2.in_channels                                                     [16, 32]\n",
            "* layers.4.2.conv1.out_channels                                                    [16, 32]\n",
            "  layers.4.2.conv1.in_channels                                                     [32]\n",
            "  layers.4.2.bn1.num_features                                                      [16, 32]\n",
            "  layers.4.2.conv2.out_channels                                                    [32]\n",
            "  layers.4.2.conv2.in_channels                                                     [16, 32]\n",
            "  layers.5.depth                                                                   [3]\n",
            "* layers.5.0.conv1.out_channels                                                    [16, 32, 48, 64]\n",
            "  layers.5.0.conv1.in_channels                                                     [32]\n",
            "  layers.5.0.bn1.num_features                                                      [16, 32, 48, 64]\n",
            "  layers.5.0.conv2.out_channels                                                    [64]\n",
            "  layers.5.0.conv2.in_channels                                                     [16, 32, 48, 64]\n",
            "* layers.5.1.conv1.out_channels                                                    [16, 32, 48, 64]\n",
            "  layers.5.1.conv1.in_channels                                                     [64]\n",
            "  layers.5.1.bn1.num_features                                                      [16, 32, 48, 64]\n",
            "  layers.5.1.conv2.out_channels                                                    [64]\n",
            "  layers.5.1.conv2.in_channels                                                     [16, 32, 48, 64]\n",
            "* layers.5.2.conv1.out_channels                                                    [16, 32, 48, 64]\n",
            "  layers.5.2.conv1.in_channels                                                     [64]\n",
            "  layers.5.2.bn1.num_features                                                      [16, 32, 48, 64]\n",
            "  layers.5.2.conv2.out_channels                                                    [64]\n",
            "  layers.5.2.conv2.in_channels                                                     [16, 32, 48, 64]\n",
            "----------------------------------------------------------------------------------------------------\n",
            "Number of configurable hparams: 6\n",
            "Total size of the search space: 5.12e+02\n",
            "Note: all constraints can be satisfied within the search space!\n",
            "\n",
            "\n",
            "Beginning pre-search estimation. If the runtime of score function is longer than a few minutes, consider subsampling the dataset used in score function. \n",
            "A PyTorch dataset can be subsampled using torch.utils.data.Subset (https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) as following:\n",
            " subset_dataset = torch.utils.data.Subset(dataset, indices)\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Collecting pre-search statistics: 100%|██████████| 18/18 [00:10<00:00,  1.76it/s, cur=layers.5.2.conv1.out_channels(64/64): 0.00] \n",
            "[num_satisfied] = 11:   0%|          | 20/10000 [00:02<17:43,  9.39it/s]  "
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Test Accuracy of Pruned ResNet20: 60.37\n"
          ]
        }
      ],
      "source": [
        "# config with more fine-grained channel choices for fastnas\n",
        "config = mtp.fastnas.FastNASConfig()\n",
        "config[\"nn.Conv2d\"][\"*\"][\"channel_divisor\"] = 16\n",
        "config[\"nn.BatchNorm2d\"][\"*\"][\"feature_divisor\"] = 16\n",
        "\n",
        "# A single 32x32 image for computing FLOPs\n",
        "dummy_input = torch.randn(1, 3, 32, 32, device=device)\n",
        "\n",
        "\n",
        "# Wrap your original validation function to only take the model as input.\n",
        "# This function acts as the score function to rank models.\n",
        "def score_func(model):\n",
        "    return evaluate(model, val_loader)\n",
        "\n",
        "\n",
        "# prune the model\n",
        "pruned_model, _ = mtp.prune(\n",
        "    model=resnet20(ckpt=\"resnet20.pth\"),\n",
        "    mode=[(\"fastnas\", config)],\n",
        "    constraints={\"flops\": 30e6},\n",
        "    dummy_input=dummy_input,\n",
        "    config={\n",
        "        \"data_loader\": train_loader,\n",
        "        \"score_func\": score_func,\n",
        "        \"checkpoint\": \"modelopt_seaarch_checkpoint_fastnas.pth\",\n",
        "    },\n",
        ")\n",
        "\n",
        "# save the pruned model for future use\n",
        "mto.save(pruned_model, \"modelopt_pruned_model_fastnas.pth\")\n",
        "\n",
        "# evaluate the pruned model\n",
        "print(f\"Test Accuracy of Pruned ResNet20: {evaluate(pruned_model, test_loader)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As we can see, the best subnet (29.6M FLOPs) has fitted our constraint of 30M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
        "\n",
        "#### Restore the pruned subnet using [mto.restore](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.restore)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {},
      "outputs": [],
      "source": [
        "pruned_model = mto.restore(resnet20(), \"modelopt_pruned_model_fastnas.pth\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Fine-tuning\n",
        "To fine-tune the subnet, you can simply repeat the training pipeline of the original model (1x training time, 0.5x-1x of original learning rate).\n",
        "The fine-tuned model constitutes the final model with the optimal trade-off between accuracy and your provided constraints that is used for deployment.\n",
        "\n",
        "Note that it would take about 5 - 15 mins to train depending on your GPU and CPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Training the model for 120 epochs...\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a5c29f88d1014d1ea86923484ca900e4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "  0%|          | 0/120 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch   1\t Training loss: 0.0011\t Val Accuracy: 79.90%\n",
            "Epoch  25\t Training loss: 0.0004\t Val Accuracy: 82.82%\n",
            "Epoch  50\t Training loss: 0.0003\t Val Accuracy: 87.00%\n",
            "Epoch  75\t Training loss: 0.0002\t Val Accuracy: 88.62%\n",
            "Epoch 100\t Training loss: 0.0000\t Val Accuracy: 90.62%\n",
            "Epoch 120\t Training loss: 0.0000\t Val Accuracy: 90.58%\n",
            "Model Trained! Restoring to parameters that gave best Val Accuracy (90.70% at Epoch 101).\n"
          ]
        }
      ],
      "source": [
        "optimizer, lr_scheduler = get_optimizer_scheduler(\n",
        "    pruned_model, 0.5 * learning_rate, weight_decay, warmup_steps, decay_steps\n",
        ")\n",
        "train_model(\n",
        "    pruned_model,\n",
        "    train_loader,\n",
        "    val_loader,\n",
        "    optimizer,\n",
        "    lr_scheduler,\n",
        "    num_epochs,\n",
        ")\n",
        "# store final model\n",
        "mto.save(pruned_model, \"modelopt_pruned_model_fastnas_finetuned.pth\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Evaluate the searched subnet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Test Accuracy of the fine-tuned pruned net: 90.28\n"
          ]
        }
      ],
      "source": [
        "# you can restore the fine-tuned model from the vanilla model\n",
        "optimized_model = mto.restore(resnet20(), \"modelopt_pruned_model_fastnas_finetuned.pth\")\n",
        "\n",
        "# test the accuracy\n",
        "print(f\"Test Accuracy of the fine-tuned pruned net: {evaluate(optimized_model, test_loader)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Conclusion?**\n",
        "\n",
        "The comparison can be summarized as below:\n",
        "\n",
        "| Model           | FLOPs      | Params     | Test Accuracy     |\n",
        "| --------------- | ---------- | ---------- | ----------------- |\n",
        "| ResNet20        | 40.6M      | 268k       | 90.9%             |\n",
        "| FastNAS subnet  | 29.6M      | 174k       | 90.3%             |\n",
        "\n",
        "As we see here, we have reduced the FLOPs and number of parameters which would also result in a improvement in latency with very little loss in accuracy. Good job!\n",
        "\n",
        "Next: checkout the [Model Optimizer GitHub repository](https://github.com/NVIDIA/TensorRT-Model-Optimizer) for more examples."
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.3"
    },
    "vscode": {
      "interpreter": {
        "hash": "8f5b80c279c91e750afc8c5018925c28766b10630af7de40323a8855d2ba7ef2"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
